#pragma once
#include <mpi.h>
#include "enhance.h"
#include "cell.h"
#include "comm.h"
struct spring_com : enhance_t {
  int mol;
  real K;
  double sum[4];
  bool first;
  vec<real> com0;
  spring_com(){
    first = true;
  }
  void pre_force(cellgrid_t *grid, mpp_t *mpp) override {
    sum[0] = 0;
    sum[1] = 0;
    sum[2] = 0;
    sum[3] = 0;
    FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
      for (int i = 0; i < cell->natom; i ++){
        if (cell->mol[i] == mol) {
          sum[0] += cell->mass[i];
          sum[1] += cell->x[i].x * cell->mass[i];
          sum[2] += cell->x[i].y * cell->mass[i];
          sum[3] += cell->x[i].z * cell->mass[i];
        }
      }
    }
    MPI_Allreduce(MPI_IN_PLACE, sum, 4, MPI_DOUBLE, MPI_SUM, mpp->comm);
    if (first){
      com0 = vec<real>(sum[1] / sum[0], sum[2] / sum[0], sum[3] / sum[0]);
      first = true;
    }
  }
  void post_force(cellgrid_t *grid, mpp_t *mpp) override {
    vec<real> com(sum[1] / sum[0], sum[2] / sum[0], sum[3] / sum[0]);
    vec<real> d = com - com0;
    real r = d.norm();
    if (r > 0) {
      vec<real> f = -2.0 * K * d / r / sum[0];
      FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
        for (int i = 0; i < cell->natom; i ++){
          if (cell->mol[i] == mol){
            cell->f[i] += f;
          }
        }
      }
    }
  };
};
